import numpy as np
import matplotlib.pylab as plt
import torch
# Demo: visualise a circular 2-D shift of a small integer grid.
#
# `data` is a 4x4 integer grid whose middle rows and middle columns repeat,
# so the wrap-around effect of the roll is easy to spot in the plots.
data = np.array(
    [
        [1, 2, 2, 3],
        [4, 5, 5, 6],
        [4, 5, 5, 6],
        [7, 8, 8, 9],
    ]
)

# Roll the grid one step up (dim 0) and one step left (dim 1), wrapping
# around the edges: shift_x[i, j] == data[(i + 1) % 4, (j + 1) % 4].
shift_x = torch.roll(
    torch.from_numpy(data),
    shifts=(-1, -1),
    dims=(0, 1),
)

# Draw each matrix with matshow and save the current figure: the original
# grid goes to 2.jpg, the shifted grid to 3.jpg.
for _matrix, _path in ((data, '2.jpg'), (shift_x.numpy(), '3.jpg')):
    plt.matshow(_matrix)
    plt.savefig(_path)
# Drop the loop variables so the module namespace matches the straight-line form.
del _matrix, _path

plt.show()